Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support SD3 #1374

Draft
wants to merge 313 commits into
base: dev
Choose a base branch
from
Draft

support SD3 #1374

wants to merge 313 commits into from

Conversation

kohya-ss
Copy link
Owner

@kohya-ss kohya-ss commented Jun 15, 2024

  • Replace SD3Tokenizer with the original CLIP-L/G/T5 tokenizers.
  • Extend the max token length to 256 for T5XXL.
  • Refactor caching for latents.
  • Refactor caching for Text Encoder outputs
  • Extract architecture-dependent parts from datasets.
  • Refactor SD/SDXL training scripts.
  • Caching attention mask etc.
  • Enable training for CLIP-L/G for SD3.
  • Add an option to use T5XXL from transformers (for fp8 quantized ver.)
  • Add attention mask for T5XXL embeds (?). https://www.reddit.com/r/StableDiffusion/comments/1e6k59c/solution_discovered_partially_implemented_for_sd3/
  • Sample images during training.
  • Cache Text Encoder outputs for sampling.
  • Update SD/SDXL sampling to use refactored Text Encoding etc.
  • Update gen_img.py to use refactored Text Encoding etc.
  • SD3 LoRA support.
  • SD3.5 support.
  • FLUX.1 fine tuning.
  • FLUX.1 LoRA support for FLUX.
  • FLUX.1 LoRA support for CLIP-L.
  • FLUX.1 masking for attention
  • FLUX.1 Sample image generation during training.
  • Update cache_latents.py and cache_text_encoder_outputs.py to support FLUX.1
  • Support .json metadata for FLUX.1 and SD3.
  • Add the captioning script with Florence-2 and/or JoyCaption.
  • Support the loss called 'prior preservation loss'.

@bghira
Copy link

bghira commented Jun 16, 2024

this is a chance to just use Diffusers modules instead of doing everything from scratch. why not take it?

@kohya-ss
Copy link
Owner Author

There are several reasons for this, but the biggest reason is that it is difficult to extend. For example, LoRA, custom ControlNet and Deep Shrink etc.

Also, considering the various processes in the training scripts, such as conditional loss, SNR, masked loss, etc., the training scripts need to be written from scratch.

@bghira
Copy link

bghira commented Jun 16, 2024

all of that is done via peft other than deepshrink but you can make a pipeline callback for that.

@bghira
Copy link

bghira commented Jun 16, 2024

i mean to use the sd3 transformer module from the diffusers project.

it is frustrating to see bespoke versions of things with unreadable comments always in this repository. can you at least leave better comments?

@kohya-ss
Copy link
Owner Author

I think transformer module should be extendable for the future. In addition, SD3 transformer is based on sd3-ref (Stability AI official repo), and modified by KBlueLeaf to support xformers etc. So it is prior to Diffusers, and not full scratch. I appreciate your understanding.

I will add better comments in future codes, including this PR.

@araleza
Copy link

araleza commented Jul 10, 2024

Hello, I have been trying out SD3 training. It seems to be working pretty well. 😊

One thing I noticed is that generation of sample images while training is not yet implemented. This made it hard for me to see how my SD3 training was going, and make adjustments.

Implementing full support for all the sample images was difficult, but I found a cheap way to get most features working, and now I have sample images working again. This code is not properly integrated with the usual sample image generation code, but if people want to use it while they wait for a real well-integrated implementation, it does the basics of what's needed.

Just go into your sd3_train.py file, and find this commented-out section:

                # sdxl_train_util.sample_images(
                #     accelerator,
                #     args,
                #     None,
                #     global_step,
                #     accelerator.device,
                #     vae,
                #     [tokenizer1, tokenizer2],
                #     [text_encoder1, text_encoder2],
                #     mmdit,
                # )

and replace that with this:

                # Generate sample images
                if args.sample_every_n_steps is not None and global_step % args.sample_every_n_steps == 0:
                    from sd3_minimal_inference import do_sample
                    from PIL import Image
                    import datetime
                    import numpy as np
                    import shlex
                    import random

                    assert args.save_t5xxl, "When generating sample images in SD3, --save_t5xxl parameter must be set"

                    with open(args.sample_prompts, 'r') as file:
                        lines = [line.strip() for line in file if line.strip()]

                    vae.to("cuda")
                    for line in lines:
                        logger.info(f"Generating image: {line}")

                        if line.find('--') != -1:
                            prompt = line[:line.find('--') - 1].strip()
                            line = line[line.find('--'):]
                        else:
                            prompt = line
                            line = ''

                        parser_s = argparse.ArgumentParser()
                        parser_s.add_argument("--w", type=int, action="store", default=1024, help="image width")
                        parser_s.add_argument("--h", type=int, action="store", default=1024, help="image height")
                        parser_s.add_argument("--s", type=int, action="store", default=30,   help="sample steps")
                        parser_s.add_argument("--l", type=int, action="store", default=4,    help="CFG")
                        parser_s.add_argument("--d", type=int, action="store", default=random.randint(0, 2**32 - 1), help="seed")
                        prompt_args = shlex.split(line)
                        args_s = parser_s.parse_args(prompt_args)

                        # prepare embeddings
                        lg_out, t5_out, pooled = sd3_utils.get_cond(prompt, sd3_tokenizer, clip_l, clip_g, t5xxl) # +'ve prompt
                        cond = torch.cat([lg_out, t5_out], dim=-2), pooled

                        lg_out, t5_out, pooled = sd3_utils.get_cond("", sd3_tokenizer, clip_l, clip_g, t5xxl) # No -'ve prompt
                        neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled

                        latent_sampled = do_sample(
                            args_s.h, args_s.w, None, args_s.d, cond, neg_cond, mmdit, args_s.s, args_s.l, weight_dtype, accelerator.device
                        )

                        # latent to image
                        with torch.no_grad():
                            image = vae.decode(latent_sampled)
                        image = image.float()
                        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
                        decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
                        decoded_np = decoded_np.astype(np.uint8)
                        out_image = Image.fromarray(decoded_np)

                        # save image
                        output_dir = os.path.join(args.output_dir, "sample")
                        os.makedirs(output_dir, exist_ok=True)
                        output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
                        out_image.save(output_path)

                    vae.to("cpu")

It supports a caption followed by the usual optional --w, --h, --s, --l, --d (for width, height, steps, cfg, and seed). It doesn't support negative captions, and it won't work right with captions longer than 75 tokens.

I'm finding sample image generation to be helpful. For example, I notice that most of my sample output images start off by looking brighter than expected (with white or bright backgrounds). Edit: Might have been my cfg of 7.5; SD3 seems to want lower cfgs. I had to push the sample count up as the cfg was lowered. Image quality still seems poor though, compared to what some people are getting out of SD3.

@araleza
Copy link

araleza commented Jul 10, 2024

Think I've found an issue that's causing the poor quality SD3 samples. The do_sample() function is not filling in the shift parameter that's required by SD3, and it's defaulting to 1.0 instead of the recommended 3.0:

class ModelSamplingDiscreteFlow:
    """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""

    def __init__(self, shift=1.0):
        self.shift = shift
        timesteps = 1000
        self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1))

From sd-script's sd3_minimal_inference.py function, do_sample()

    model_sampling = sd3_utils.ModelSamplingDiscreteFlow()

From the SD3 paper:
image

The paper also seems to say that these shifts to the sigmas should be present during training. Are these maybe missing too, @kohya-ss? (Edit: No, a shift value of 3.0 is already set up correctly during training)

@kohya-ss
Copy link
Owner Author

Think I've found an issue that's causing the poor quality SD3 samples. The do_sample() function is not filling in the shift parameter that's required by SD3, and it's defaulting to 1.0 instead of the recommended 3.0:

Thank you! I fixed it. The generated images seemed to be better now.

@kohya-ss
Copy link
Owner Author

I agree that the sample image generation is really useful. In my understanding, T5XXL is on CPU, so I wonder get_cond may take a long time. How much time it takes?

I think it might be necessary to get TE's output for the sampling prompt in advance, at the same time the TE caching. However, if T5XXL works on CPU with an acceptable time, the implementation of the sample generation will be much easier (like your implementation :) .

@bghira
Copy link

bghira commented Jul 11, 2024

it takes about 30-50 seconds to run T5 XL on the CPU, i think XXL is even worse latency for each embed

@araleza
Copy link

araleza commented Jul 11, 2024

I agree that the sample image generation is really useful. In my understanding, T5XXL is on CPU, so I wonder get_cond may take a long time. How much time it takes?

@kohya-ss, the calls to get_cond() only take around 2 seconds each on my machine. The whole sample image generation takes just 16 seconds per image for me, and I am still doing 80 sample steps for the images. :D

My PC is an ordinary (but good) home PC machine with a 13th gen Intel i7, and I've got 64 GB of CPU RAM. Perhaps the people finding the T5 XL to be very slow are running out of CPU memory and swapping the T5 XL out to disk without realizing? @bghira

@kohya-ss
Copy link
Owner Author

Thank you @bghira and @araleza ! I test with T5XXL on GPU, and it takes less than 2 seconds on GPU as araleza wrote, and it seems to require about 32GB additional main RAM... So practically, it may be needed to cache TE's output of the sample prompt.

@bghira
Copy link

bghira commented Jul 11, 2024

yes, and the text encoder being trained will cause the problem. but maybe they shouldn't be trained 🤷

@bghira
Copy link

bghira commented Jul 11, 2024

@araleza lol i'm running on 8x H100 system with more than 1.6TB of system memory and high-end EPYC

@araleza
Copy link

araleza commented Jul 11, 2024

Hello! So far, I've had to run SD3 training in full_bf16 mode, because I run out of (24GB) VRAM if I did not choose this option. I've now found a way to run training in full fp32.

This line of code pushes VRAM usage much higher:

            mmdit = accelerator.prepare(mmdit)

and then shortly after, the VRAM usage drops when the T5XXL is moved from GPU to CPU:

            t5xxl.to("cpu", dtype=torch.float32)

If I change these two code sections around, the peak VRAM usage is much lower. So that's switching the section that starts if args.cache_text_encoder_outputs: with the section that starts if args.deepspeed:.

It looks like training in fp32 mode may improve quality significantly, although it is slower.

I may be using more VRAM than needed because I'm not using Deepspeed yet. (I tried to use it, but I got an error message, and I haven't looked into what the cause of this error message is). But these code sections are still worth swapping to avoid the VRAM spike for people who are not using it.

(Edit: I'm not sure Deepspeed even works with SD3 just now, so maybe everyone with 24GB is currently running out of VRAM when trying to use fp32?)

@FurkanGozukara
Copy link

reducing peak vram is so important

@kohya-ss
Copy link
Owner Author

If I change these two code sections around, the peak VRAM usage is much lower. So that's switching the section that starts with the section that starts .if args.cache_text_encoder_outputs:``if args.deepspeed:

That's right. When we cache the Text Encoder outputs (it is necessary for now), Text Encoders can be moved to CPU before preparing MMDiT. I updated the script.

In my env, SD3 training works with the mixed precision with 24GB, with AdaFactor optimizer and gradient checkpointing. I believe it will not work without the mixed precision.

@araleza
Copy link

araleza commented Jul 12, 2024

In my env, SD3 training works with the mixed precision with 24GB, with AdaFactor optimizer and gradient checkpointing. I believe it will not work without the mixed precision.

Huh, this is confusing... 🤔 I've now been training SD3 with full fp32 for everything over the last day, and it all works great. The quality with fp32 is amazing compared to bf16. I'm also using Adafactor with the same optimizer settings as I usually use for sdxl (i.e. 'scale_parameter=False relative_step=False warmup_init=False').

I'm even using a batch size of 4, and I still have VRAM to spare:

|    0   N/A  N/A      8186      C   .../Dev/sd3/sd-scripts/venv/bin/python      20202MiB |

I don't have the --mixed_precision=bf16 or the --full_bf16 flags set (or any fp16 flags either).

@bghira
Copy link

bghira commented Jul 12, 2024

yes, but how high can the batch size go with mixed precision? 4 is very low, bordering on useless?

@araleza
Copy link

araleza commented Jul 12, 2024

I've been using batch size 4 for a long time - it seems pretty good for fine tuning to add a concept. Maybe people who want to do continued pretraining would want to use a higher batch size, but they'll have more VRAM than 24 GB.

I've just tried batch size 6 (with full fp32) now, and it works as well. Batch size 8 ran out of VRAM after around 50 steps.

@kohya-ss, here's my command line parameters if they're useful to you:

--pretrained_model_name_or_path="/home/ara/Dev/training/earthscape/kohya/dreambooth/at-step00008900.safetensors" --clip_l="/home/ara/Dev/sd3/clip_l.safetensors" --clip_g="/home/ara/Dev/sd3/clip_g.safetensors" --enable_bucket --min_bucket_reso=64 --max_bucket_reso=1024 --train_data_dir="/home/ara/Dev/training/earthscape/kohya/img" --resolution="1024,1024" --output_dir="/home/ara/Dev/training/earthscape/kohya/dreambooth" --caption_extension=".txt" --logging_dir="/home/ara/Dev/training/earthscape/kohya/log" --save_model_as=safetensors --lr_scheduler_num_cycles="20000" --max_data_loader_n_workers="0" --lr_scheduler="constant_with_warmup" --lr_warmup_steps="100" --max_train_steps="160000" --optimizer_type="Adafactor" --optimizer_args scale_parameter=False relative_step=False warmup_init=False --max_data_loader_n_workers="0" --bucket_reso_steps=32 --v_pred_like_loss="0.5" --save_every_n_steps="100" --save_last_n_steps="200" --gradient_checkpointing --sdpa --bucket_no_upscale --sample_sampler=k_dpm_2 --sample_prompts="/home/ara/Dev/training/earthscape/kohya/dreambooth/sample/prompt.txt" --sample_every_n_steps="100" --cache_latents --loss_type=huber --train_batch_size="4" --enable_wildcard --alpha_mask --cache_text_encoder_outputs  --cache_latents --cache_latents_to_disk --learning_rate=4e-7 --save_t5xxl

@kohya-ss
Copy link
Owner Author

@araleza Thank you! I've tested without the mixed precision (--mixed_precision no for accelerate, and removing --mixed_precision option for sd3_train), and it works!

Surprisingly, with batch_size=1, fp32 training seems to use less memory than bf16. I am wondering if there might be something wrong with the model and will investigate.

@FurkanGozukara
Copy link

Currently are we able to train clip text encoders? That lower ones not t5

I am guessing training model + clip text encoders would yield better but I didn't try yet I am still waiting main branch merge

@araleza
Copy link

araleza commented Jul 13, 2024

@FurkanGozukara, training the original text encoders (clip_l and clip_g) is not currently supported. That's because SD3 training currently requires caching of the text encoder outputs at the start of training, which means the text encoder weights cannot then be updated.

The reason for the forced caching is that running the T5XXL text encoder takes several seconds, assuming you can even fit it in CPU RAM. Kohya mentioned it takes 32 GB, which is more than some people have. Even if it fit in memory, it takes around 2 seconds to convert the captions into embeddings, which would have to happen on each training step.

The obvious solution would be to cache the T5XXL outputs only, allowing clip_l and clip_g to train. But code to cache one of the three text encoders but not the other two has not currently been written.

@FurkanGozukara
Copy link

@araleza thanks a lot for detailed explanation. I hope that code comes sooner

@sdbds
Copy link
Contributor

sdbds commented Jul 13, 2024

image

TEST RESULTS:
FP32 VS MixFP16

FP32 is much better and there's no difference in VRAM usage.
It seems like SDXL doesn't have that noticeable of a difference.

@bghira
Copy link

bghira commented Jul 13, 2024

fp32 should definitely use more vram than full bf16 / fp16 and if it isn't, there might be something wrong.

@bghira
Copy link

bghira commented Jul 20, 2024

make sure to apply t5 attention mask in the attentionprocessor @kohya-ss

@kohya-ss
Copy link
Owner Author

kohya-ss commented Dec 2, 2024

FLUX.1 ControlNet training is merged #1813. Thank you minux302 for the contribution!

Currently, 80GB VRAM is needed for 1024x1024 training. As soon as I have time, I will try using block swap for ControlNet to see if I can reduce the required VRAM.

@FurkanGozukara
Copy link

@kohya-ss amazing

any sample dataset and dataset toml file that we can take a look at for controlnet?

@kohya-ss
Copy link
Owner Author

kohya-ss commented Dec 2, 2024

flux_train_control_net.py now supports --blocks_to_swap. It should run with 16 or 24GB VRAM.

@kohya-ss
Copy link
Owner Author

kohya-ss commented Dec 2, 2024

any sample dataset and dataset toml file that we can take a look at for controlnet?

We can use fill50k dataset from the original ControlNet: https://huggingface.co/lllyasviel/ControlNet/tree/main/training

Places the captions as .txt files in target. The dataset config is something like this:

[general]
resolution = [1024, 1024]

[[datasets]]
batch_size = 1
enable_bucket = false

  [[datasets.subsets]]
  image_dir = "/path/to/fill50k/target"
  caption_extension = ".txt"
  conditioning_data_dir = "/path/to/fill50k/source"

I'll write more details when I have time.

@FurkanGozukara
Copy link

@kohya-ss thanks

now i looked that dataset

e.g.

image

this is both source and target

{"source": "source/19123.png", "target": "target/19123.png", "prompt": "sandy brown circle with moccasin background"}

if you can elaborate more that would be amazing thank you

@kohya-ss
Copy link
Owner Author

kohya-ss commented Dec 4, 2024

target/19123.png should be like this:
image

I think this is an appropriate image for the caption "sandy brown circle with moccasin background".

@FurkanGozukara
Copy link

@kohya-ss thank you so much I understand now

so this is training a canny controlnet right?

@kohya-ss
Copy link
Owner Author

kohya-ss commented Dec 4, 2024

so this is training a canny controlnet right?

This is just a toy dataset, it can only control the position of the circle with a condition image, and the color of the circle and background with a prompt...

@FurkanGozukara
Copy link

so this is training a canny controlnet right?

This is just a toy dataset, it can only control the position of the circle with a condition image, and the color of the circle and background with a prompt...

so are there any real controlnet dataset that we can take a look at? thank you

kohya-ss and others added 4 commits December 7, 2024 15:12
Workflow tests fixes and documentation
* Update sd3_train.py

* add freeze block lr

* Update train_util.py

* update

* Revert "add freeze block lr"

This reverts commit 8b16535.

# Conflicts:
#	library/train_util.py
#	sd3_train.py

* use same control net model path

* use controlnet_model_name_or_path
@kohya-ss
Copy link
Owner Author

kohya-ss commented Dec 7, 2024

The option to specify the model name for existing ControlNet model has been unified for each ControlNet training script. Please specify --controlnet_model_name_or_path. Thanks to sdbds!

@kohya-ss
Copy link
Owner Author

RAdamScheduleFree optimizer is now supported. Please update schedulefree to 1.4.

@FurkanGozukara
Copy link

any chance we could get pivotal tuning ? https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning

@araleza
Copy link

araleza commented Dec 18, 2024

I just wanted to report that after using it for a few days, the new schedule-free RAdam optimizer seems very strong. I thought it might not be that good as it just seemed to offer a better warmup (so better for a few iterations and then just the same as normal after that), but it seems to produce improved quality results even after that.

You can activate it with --optimizer_type radamschedulefree on the command line. Despite the documentation recommending a default learning rate of 2.5e-3, I'm finding 1e-5 to 3e-5 to be more appropriate, at least for training a LoRA on some of my datasets.

@bghira
Copy link

bghira commented Dec 19, 2024

i think every new option added has somebody claim its the best or strongest choice lol

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.